{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "BERT-output.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyMnDskL7tkqqpgIrwBJogXe"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "uFkQG5vh0Y-3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!pip install pytorch_pretrained_bert\n",
        "!apt-get install -y -qq software-properties-common python-software-properties module-init-tools\n",
        "!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null\n",
        "!apt-get update -qq 2>&1 > /dev/null\n",
        "!apt-get -y install -qq google-drive-ocamlfuse fuse\n",
        "from google.colab import auth\n",
        "auth.authenticate_user()\n",
        "from oauth2client.client import GoogleCredentials\n",
        "creds = GoogleCredentials.get_application_default()\n",
        "import getpass\n",
        "!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL\n",
        "vcode = getpass.getpass()\n",
        "!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}\n",
        "!mkdir -p drive\n",
        "!google-drive-ocamlfuse drive\n",
        "import os\n",
        "import sys\n",
        "os.chdir('drive/Colab Notebooks')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "s4CDPO-UzZXA",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# coding: UTF-8\n",
        "import torch\n",
        "import time \n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F \n",
        "import pandas as pd \n",
        "import numpy as np \n",
        "from tqdm import tqdm \n",
        "from torch.utils.data import *\n",
        "from pytorch_pretrained_bert import BertModel, BertTokenizer, BertConfig, BertAdam"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1aFTTq3Yz5X0",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tokenizer = BertTokenizer(vocab_file=\"data/vocab.txt\")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GkbKOFJM0WTa",
        "colab_type": "code",
        "outputId": "21e129e3-ebc7-4047-806c-1f59f74e90e8",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "input_ids = []\n",
        "input_types = []\n",
        "input_masks = []\n",
        "pad_size = 64\n",
        " \n",
        "with open(\"data/test.txt\") as f:\n",
        "    for i, l in tqdm(enumerate(f)): \n",
        "        x1, x2 = l.strip().split('\\t')\n",
        "        x1 = tokenizer.tokenize(x1)\n",
        "        x2 = tokenizer.tokenize(x2)\n",
        "        tokens = [\"[CLS]\"] + x1 + [\"[SEP]\"] + x2 +[\"[SEP]\"]\n",
        "\n",
        "        ids = tokenizer.convert_tokens_to_ids(tokens)\n",
        "        types = [0] *(len(ids))\n",
        "        masks = [1] * len(ids)\n",
        "        if len(ids) < pad_size:\n",
        "            types = types + [1] * (pad_size - len(ids))\n",
        "            masks = masks + [0] * (pad_size - len(ids))\n",
        "            ids = ids + [0] * (pad_size - len(ids))\n",
        "        else:\n",
        "            types = types[:pad_size]\n",
        "            masks = masks[:pad_size]\n",
        "            ids = ids[:pad_size]\n",
        "        input_ids.append(ids)\n",
        "        input_types.append(types)\n",
        "        input_masks.append(masks)\n",
        "        assert len(ids) == len(masks) == len(types) == pad_size"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "12500it [00:03, 4138.81it/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7AeqhTuY0Wdy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "input_ids_train = np.array(input_ids)\n",
        "input_types_train = np.array(input_types)\n",
        "input_masks_train = np.array(input_masks)\n",
        "\n",
        "BATCH_SIZE = 1\n",
        "train_data = TensorDataset(torch.LongTensor(input_ids_train), \n",
        "                           torch.LongTensor(input_types_train), \n",
        "                           torch.LongTensor(input_masks_train))\n",
        " \n",
        "train_loader = DataLoader(train_data, batch_size=BATCH_SIZE)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MlxK1I4t0Wgh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class Model(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(Model, self).__init__()\n",
        "        self.bert = BertModel.from_pretrained(path)\n",
        "        for param in self.bert.parameters():\n",
        "            param.requires_grad = True\n",
        "        self.fc = nn.Linear(768, 2)\n",
        "\n",
        "    def forward(self, x):\n",
        "        context = x[0]\n",
        "        types = x[1]\n",
        "        mask = x[2]\n",
        "        _, pooled = self.bert(context, token_type_ids=types, \n",
        "                              attention_mask=mask, \n",
        "                              output_all_encoded_layers=False)\n",
        "        out = self.fc(pooled)\n",
        "        return out"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "H-ydRZwm0wjy",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "pthpath = 'colab1.pth'\n",
        "model = Model()\n",
        "model = torch.load(pthpath)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ZJcE7YSF0Wkr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "DEVICE = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "def test(model, device, test_loader):\n",
        "    res = []\n",
        "    for batch_idx, (x1,x2,x3) in enumerate(test_loader):\n",
        "        x1,x2,x3 = x1.to(device), x2.to(device), x3.to(device)\n",
        "        with torch.no_grad():\n",
        "            y_ = model([x1,x2,x3])\n",
        "        pred = y_.max(-1, keepdim=True)[1]\n",
        "        res.append(pred)\n",
        "    return res"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7jmFSzam0Wjf",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "raw = test(model,DEVICE,train_loader)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xJnuOaRj1Gny",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "result = []\n",
        "\n",
        "for i in raw:\n",
        "  result.append(i.cpu().item())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CEPOcAtU4uVc",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "s = ''\n",
        "for i in r:\n",
        "    s = s + str(i) + '\\n'"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8gjJYbvV4PJ-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "f = open('qaq.txt','w')\n",
        "f.write(s)\n",
        "f.close()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
